{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working directory: /home/xingyao6/llm-agent\n"
     ]
    }
   ],
   "source": [
    "from glob import glob\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "from typing import List, Dict, Mapping, Tuple, Union, Optional\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Register `progress_apply` on pandas objects (used by analyze_dataset).\n",
    "tqdm.pandas()\n",
    "\n",
    "# Resolve the repo root only once per kernel: because of the `os.chdir` below,\n",
    "# recomputing it on a re-run would climb three more directory levels each time.\n",
    "# Assumes the kernel starts three levels below the repo root -- adjust if not.\n",
    "if \"ROOT_DIR\" not in globals():\n",
    "    ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))) # Should be your path to the repo `mint`\n",
    "    # Make repo-local modules importable; inside the guard so re-runs do not add duplicates.\n",
    "    sys.path.insert(0, ROOT_DIR)\n",
    "os.chdir(ROOT_DIR)\n",
    "print(f\"Working directory: {os.getcwd()}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tokenizer and Special Symbols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32002"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ChatML-style turn delimiters, following\n",
    "#   https://github.com/openai/openai-python/blob/main/chatml.md\n",
    "#   https://huggingface.co/OpenAssistant/codellama-13b-oasst-sft-v10\n",
    "# Prompt layout:\n",
    "#   <|im_start|>system\n",
    "#   {system_message}<|im_end|>\n",
    "#   <|im_start|>user\n",
    "#   {prompt}<|im_end|>\n",
    "#   <|im_start|>assistant\n",
    "INST_START, INST_END = \"<|im_start|>\", \"<|im_end|>\"\n",
    "\n",
    "# Load the Llama-2 tokenizer and register both delimiters as additional special tokens.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n",
    "tokenizer.add_special_tokens(dict(additional_special_tokens=[INST_START, INST_END]))\n",
    "\n",
    "# Show the tokenizer length after the additions.\n",
    "len(tokenizer)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passage of the original system prompt that tells the model to wrap its reasoning in\n",
    "# <thought> tags. `convert_state_to_traj` swaps it for SYSTEM_MSG_NEW via `str.replace`,\n",
    "# so this literal has to match the stored prompt character-for-character to take effect.\n",
    "SYSTEM_MSG_TO_REPLACE = \"\"\"At each turn, you should first provide your step-by-step thinking for solving the task. Your thought process should be enclosed using \"<thought>\" tag, for example: <thought> I need to print \"Hello World!\" </thought>.\n",
    "\n",
    "After that, you have two options:\"\"\"\n",
    "\n",
    "# Replacement passage: the same instruction without the <thought>-tag requirement.\n",
    "SYSTEM_MSG_NEW = \"\"\"At each turn, you should first provide your step-by-step thinking for solving the task. After that, you have two options:\"\"\"\n",
    "\n",
    "def convert_state_to_traj(state) -> List[Mapping[str, str]]:\n",
    "    \"\"\"Convert a logged interaction state into a list of chat turns.\n",
    "\n",
    "    The first history entry must be a user turn made of three parts (system\n",
    "    message, in-context example, task) joined by the separator used in the split\n",
    "    below. It becomes a `system` turn, with SYSTEM_MSG_TO_REPLACE rewritten to\n",
    "    SYSTEM_MSG_NEW, plus a `user` turn holding only the task; the in-context\n",
    "    example is dropped. Later user turns are copied, assistant turns lose their\n",
    "    <thought>...</thought> wrappers (the wrapped text is kept), and turns with\n",
    "    any other role are skipped.\n",
    "\n",
    "    Args:\n",
    "        state: Mapping whose `history` entry is a list of dicts with `role`\n",
    "            and `content` keys.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with `role` and `content` keys.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the first turn is not a user turn, or does not split\n",
    "            into exactly three parts.\n",
    "    \"\"\"\n",
    "    STRIP_WORDS = [\n",
    "        \"Assistant:\",\n",
    "    ]\n",
    "    history = state[\"history\"]\n",
    "    res = []\n",
    "\n",
    "    for i, turn in enumerate(history):\n",
    "        role = turn[\"role\"]\n",
    "        text = turn[\"content\"].strip()\n",
    "        for word in STRIP_WORDS:\n",
    "            # Drop the label only when it is a real prefix. `str.lstrip(word)` treats\n",
    "            # `word` as a set of characters and would also eat the start of ordinary\n",
    "            # text (e.g. 'satisfy' -> 'fy', 'sin(x)' -> '(x)').\n",
    "            if text.startswith(word):\n",
    "                text = text[len(word):]\n",
    "        text = text.strip()\n",
    "        if i == 0:\n",
    "            assert role == \"user\"\n",
    "            # specifically handle the first turn\n",
    "            splited_text = text.split(\"\\n---\\n\\nTask:\")\n",
    "            assert len(splited_text) == 3, f\"Expecting 3 parts. But got {len(splited_text)} parts: \\n{text}\"\n",
    "            # The in-context example (middle part) is intentionally discarded.\n",
    "            system_message, in_context_example, task = splited_text\n",
    "            system_message = system_message.replace(SYSTEM_MSG_TO_REPLACE, SYSTEM_MSG_NEW)\n",
    "            res.append({\n",
    "                \"role\": \"system\",\n",
    "                \"content\": system_message.strip()\n",
    "            })\n",
    "            res.append({\n",
    "                \"role\": \"user\",\n",
    "                \"content\": f\"Task:\\n{task.strip()}\"\n",
    "            })\n",
    "\n",
    "        elif role == \"user\":\n",
    "            res.append({\n",
    "                \"role\": \"user\",\n",
    "                \"content\": text\n",
    "            })\n",
    "        elif role == \"assistant\":\n",
    "            # Remove the <thought> / </thought> tags but keep the text between them,\n",
    "            # trimmed of surrounding whitespace. DOTALL lets a thought span several lines.\n",
    "            text = re.sub(r\"<thought>(.*?)</thought>\", lambda match: match.group(1).strip(), text, flags=re.DOTALL)\n",
    "\n",
    "            res.append({\n",
    "                \"role\": \"assistant\",\n",
    "                \"content\": text\n",
    "            })\n",
    "    return res\n",
    "\n",
    "def format_traj_to_str(traj: List[Mapping[str, str]]) -> str:\n",
    "    \"\"\"Serialize a trajectory into a single delimited string.\n",
    "\n",
    "    Each turn is rendered as INST_START followed by the role, a newline, the\n",
    "    content, INST_END and a trailing newline; rendered turns are concatenated in\n",
    "    order. An empty trajectory yields an empty string.\n",
    "    \"\"\"\n",
    "    rendered_turns = [\n",
    "        f\"{INST_START}{message['role']}\\n{message['content']}{INST_END}\\n\"\n",
    "        for message in traj\n",
    "    ]\n",
    "    return \"\".join(rendered_turns)\n",
    "\n",
    "def visualize_traj(traj: List[Mapping[str, str]]):\n",
    "    \"\"\"Pretty-print a trajectory to stdout, one ANSI color per role.\n",
    "\n",
    "    A separator line is printed first and again after every printed turn.\n",
    "    Only `user`, `assistant` and `system` turns are printed; turns with any\n",
    "    other role produce no output.\n",
    "    \"\"\"\n",
    "    separator = \"==========================\"\n",
    "    reset = \"\\033[0m\"\n",
    "    # role -> (ANSI color prefix, heading printed before the content)\n",
    "    role_styles = {\n",
    "        \"user\": (\"\\033[1;34;40m\", \"USER:\\n\"),            # blue\n",
    "        \"assistant\": (\"\\033[1;32;40m\", \"ASSISTANT:\\n\"),  # green\n",
    "        \"system\": (\"\\033[1;33;40m\", \"SYSTEM:\\n\"),        # yellow\n",
    "    }\n",
    "\n",
    "    print(separator)\n",
    "    for message in traj:\n",
    "        style = role_styles.get(message[\"role\"])\n",
    "        if style is None:\n",
    "            continue\n",
    "        color, heading = style\n",
    "        print(color + heading + f\"{message['content']}\" + reset)\n",
    "        print(separator)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Token Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_dataset(df, max_tokens=4096):\n",
    "    \"\"\"Count tokens per conversation and print the capped total.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with a `conversations` column; each entry is rendered with\n",
    "            `format_traj_to_str` before tokenization.\n",
    "        max_tokens: Upper bound applied to every per-conversation count.\n",
    "\n",
    "    Returns:\n",
    "        A Series named `token_length` holding the capped token counts.\n",
    "    \"\"\"\n",
    "    def _count_tokens(conversation_text):\n",
    "        # Length of the `input_ids` produced by the module-level tokenizer.\n",
    "        return len(tokenizer(conversation_text)[\"input_ids\"])\n",
    "\n",
    "    rendered = df[\"conversations\"].apply(format_traj_to_str)\n",
    "    token_counts = rendered.progress_apply(_count_tokens).rename(\"token_length\")\n",
    "\n",
    "    # Cap each conversation at max_tokens before summing.\n",
    "    capped_counts = token_counts.clip(upper=max_tokens)\n",
    "    print(f\"Total number of tokens: {capped_counts.sum():,}\")\n",
    "\n",
    "    return capped_counts\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['oct24_full5545.jsonl',\n",
       " 'oct20_apps354.jsonl',\n",
       " 'openorca.n50000.jsonl',\n",
       " 'openorca.n10000.jsonl',\n",
       " 'agent_instruct.jsonl',\n",
       " 'sharegpt.n10000.jsonl',\n",
       " 'openorca.n30000.jsonl',\n",
       " 'oct30_easy8155.jsonl',\n",
       " 'evolcode.n10000.jsonl',\n",
       " 'sharegpt.jsonl',\n",
       " 'nov2_gpt4hard411.jsonl',\n",
       " 'oct28_full6728.jsonl',\n",
       " 'sharegpt_gpt4.jsonl']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the available dataset files; sorted because os.listdir returns entries in arbitrary order.\n",
    "sorted(os.listdir(\"data/datasets/\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/9 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Analyzing dataset openorca.n50000.jsonl\n",
      "Number of conversations: 50000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50000/50000 [00:32<00:00, 1555.34it/s]\n",
      " 11%|█         | 1/9 [00:33<04:26, 33.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 14,028,347\n",
      "- Analyzing dataset openorca.n10000.jsonl\n",
      "Number of conversations: 10000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [00:05<00:00, 1775.18it/s]\n",
      " 22%|██▏       | 2/9 [00:39<02:00, 17.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 2,813,879\n",
      "- Analyzing dataset openorca.n30000.jsonl\n",
      "Number of conversations: 30000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30000/30000 [00:18<00:00, 1614.55it/s]\n",
      " 33%|███▎      | 3/9 [00:58<01:48, 18.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 8,424,516\n",
      "- Analyzing dataset agent_instruct.jsonl\n",
      "Number of conversations: 1866\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1866/1866 [00:05<00:00, 345.95it/s]\n",
      " 44%|████▍     | 4/9 [01:03<01:05, 13.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 2,501,909\n",
      "- Analyzing dataset sharegpt.n10000.jsonl\n",
      "Number of conversations: 10000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [00:42<00:00, 234.36it/s]\n",
      " 56%|█████▌    | 5/9 [01:47<01:35, 23.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 17,932,833\n",
      "- Analyzing dataset sharegpt.jsonl\n",
      "Number of conversations: 39537\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 39537/39537 [03:04<00:00, 214.49it/s]\n",
      " 67%|██████▋   | 6/9 [04:54<03:58, 79.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 70,947,803\n",
      "- Analyzing dataset sharegpt_gpt4.jsonl\n",
      "Number of conversations: 4583\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4583/4583 [00:51<00:00, 88.80it/s]\n",
      " 78%|███████▊  | 7/9 [05:46<02:21, 70.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 10,868,124\n",
      "- Analyzing dataset nov2_gpt4hard411.jsonl\n",
      "Number of conversations: 411\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 411/411 [00:01<00:00, 299.74it/s]\n",
      " 89%|████████▉ | 8/9 [05:48<00:48, 48.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 592,701\n",
      "- Analyzing dataset oct24_full5545.jsonl\n",
      "Number of conversations: 5545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5545/5545 [00:17<00:00, 325.31it/s]\n",
      "100%|██████████| 9/9 [06:05<00:00, 40.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of tokens: 8,251,072\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-conversation token cap used for all statistics (forwarded to analyze_dataset).\n",
    "MAX_TOKENS = 4096\n",
    "dataset_dir = 'data/datasets'\n",
    "DATASET_NAMES = [\n",
    "    'openorca.n50000.jsonl',\n",
    "    'openorca.n10000.jsonl',\n",
    "    'openorca.n30000.jsonl',\n",
    "\n",
    "    'agent_instruct.jsonl',\n",
    "\n",
    "    'sharegpt.n10000.jsonl',\n",
    "    'sharegpt.jsonl',\n",
    "    'sharegpt_gpt4.jsonl',\n",
    "\n",
    "    'nov2_gpt4hard411.jsonl',\n",
    "    'oct24_full5545.jsonl',\n",
    "]\n",
    "\n",
    "\n",
    "def load_conversations_df(ds_path):\n",
    "    \"\"\"Load a JSONL dataset into a DataFrame that has a `conversations` column.\n",
    "\n",
    "    If the parsed records already carry a `conversations` field the frame is\n",
    "    returned as read by pandas. Otherwise every line is treated as one whole\n",
    "    conversation and wrapped into a frame with `id` and `conversations` columns.\n",
    "    \"\"\"\n",
    "    df = pd.read_json(ds_path, lines=True, orient=\"records\")\n",
    "    if \"conversations\" in df.columns:\n",
    "        return df\n",
    "    # Explicit UTF-8 so this fallback does not depend on the platform locale\n",
    "    # (pd.read_json above already defaults to UTF-8).\n",
    "    with open(ds_path, encoding=\"utf-8\") as f:\n",
    "        conversations = [json.loads(line) for line in f]\n",
    "    return (\n",
    "        pd.Series(conversations)\n",
    "        .to_frame()\n",
    "        .reset_index()\n",
    "        .rename(columns={\"index\": \"id\", 0: \"conversations\"})\n",
    "    )\n",
    "\n",
    "\n",
    "# dataset file name -> Series of capped per-conversation token counts\n",
    "dataset_stats = {}\n",
    "for dataset in tqdm(DATASET_NAMES):\n",
    "    ds_path = os.path.join(dataset_dir, dataset)\n",
    "    print(f'- Analyzing dataset {dataset}')\n",
    "    df = load_conversations_df(ds_path)\n",
    "    print(f\"Number of conversations: {len(df)}\")\n",
    "    dataset_stats[dataset] = analyze_dataset(df, max_tokens=MAX_TOKENS)\n",
    "\n",
    "\n",
    "# Collect one summary record per dataset, then build the frame in one go.\n",
    "dataset_stats_records = []\n",
    "\n",
    "for dataset_name, seq_lens in dataset_stats.items():\n",
    "    num_examples = len(seq_lens)\n",
    "    # seq_lens was already capped at MAX_TOKENS by analyze_dataset, so no second clip.\n",
    "    num_tokens = seq_lens.sum()\n",
    "    dataset_stats_records.append({\n",
    "        \"dataset\": dataset_name,\n",
    "        \"num_examples\": num_examples,\n",
    "        \"num_tokens\": num_tokens,\n",
    "        \"avg_tokens\": num_tokens / num_examples\n",
    "    })\n",
    "\n",
    "dataset_stats_df = pd.DataFrame(dataset_stats_records)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_74020_row0_col1, #T_74020_row1_col2, #T_74020_row6_col3 {\n",
       "  background-color: #08306b;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row0_col2 {\n",
       "  background-color: #d2e3f3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row0_col3, #T_74020_row2_col3, #T_74020_row3_col3, #T_74020_row8_col1, #T_74020_row8_col2 {\n",
       "  background-color: #f7fbff;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row1_col1 {\n",
       "  background-color: #1967ad;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row1_col3, #T_74020_row4_col3 {\n",
       "  background-color: #2777b8;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row2_col1 {\n",
       "  background-color: #4b98ca;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row2_col2 {\n",
       "  background-color: #e1edf8;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row3_col1, #T_74020_row4_col1 {\n",
       "  background-color: #d1e2f3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row3_col2 {\n",
       "  background-color: #f1f7fd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row4_col2 {\n",
       "  background-color: #c7dbef;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row5_col1 {\n",
       "  background-color: #e3eef8;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row5_col2 {\n",
       "  background-color: #e2edf8;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row5_col3 {\n",
       "  background-color: #529dcc;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row6_col1 {\n",
       "  background-color: #e7f0fa;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row6_col2 {\n",
       "  background-color: #dae8f6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row7_col1 {\n",
       "  background-color: #f2f7fd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row7_col2 {\n",
       "  background-color: #f2f8fd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_74020_row7_col3 {\n",
       "  background-color: #69add5;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_74020_row8_col3 {\n",
       "  background-color: #58a1cf;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_74020\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_74020_level0_col0\" class=\"col_heading level0 col0\" >dataset</th>\n",
       "      <th id=\"T_74020_level0_col1\" class=\"col_heading level0 col1\" >num_examples</th>\n",
       "      <th id=\"T_74020_level0_col2\" class=\"col_heading level0 col2\" >num_tokens</th>\n",
       "      <th id=\"T_74020_level0_col3\" class=\"col_heading level0 col3\" >avg_tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_74020_row0_col0\" class=\"data row0 col0\" >openorca.n50000.jsonl</td>\n",
       "      <td id=\"T_74020_row0_col1\" class=\"data row0 col1\" >50,000</td>\n",
       "      <td id=\"T_74020_row0_col2\" class=\"data row0 col2\" >14,028,347</td>\n",
       "      <td id=\"T_74020_row0_col3\" class=\"data row0 col3\" >280.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row1\" class=\"row_heading level0 row1\" >5</th>\n",
       "      <td id=\"T_74020_row1_col0\" class=\"data row1 col0\" >sharegpt.jsonl</td>\n",
       "      <td id=\"T_74020_row1_col1\" class=\"data row1 col1\" >39,537</td>\n",
       "      <td id=\"T_74020_row1_col2\" class=\"data row1 col2\" >70,947,803</td>\n",
       "      <td id=\"T_74020_row1_col3\" class=\"data row1 col3\" >1794.47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_74020_row2_col0\" class=\"data row2 col0\" >openorca.n30000.jsonl</td>\n",
       "      <td id=\"T_74020_row2_col1\" class=\"data row2 col1\" >30,000</td>\n",
       "      <td id=\"T_74020_row2_col2\" class=\"data row2 col2\" >8,424,516</td>\n",
       "      <td id=\"T_74020_row2_col3\" class=\"data row2 col3\" >280.82</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row3\" class=\"row_heading level0 row3\" >1</th>\n",
       "      <td id=\"T_74020_row3_col0\" class=\"data row3 col0\" >openorca.n10000.jsonl</td>\n",
       "      <td id=\"T_74020_row3_col1\" class=\"data row3 col1\" >10,000</td>\n",
       "      <td id=\"T_74020_row3_col2\" class=\"data row3 col2\" >2,813,879</td>\n",
       "      <td id=\"T_74020_row3_col3\" class=\"data row3 col3\" >281.39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_74020_row4_col0\" class=\"data row4 col0\" >sharegpt.n10000.jsonl</td>\n",
       "      <td id=\"T_74020_row4_col1\" class=\"data row4 col1\" >10,000</td>\n",
       "      <td id=\"T_74020_row4_col2\" class=\"data row4 col2\" >17,932,833</td>\n",
       "      <td id=\"T_74020_row4_col3\" class=\"data row4 col3\" >1793.28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row5\" class=\"row_heading level0 row5\" >8</th>\n",
       "      <td id=\"T_74020_row5_col0\" class=\"data row5 col0\" >oct24_full5545.jsonl</td>\n",
       "      <td id=\"T_74020_row5_col1\" class=\"data row5 col1\" >5,545</td>\n",
       "      <td id=\"T_74020_row5_col2\" class=\"data row5 col2\" >8,251,072</td>\n",
       "      <td id=\"T_74020_row5_col3\" class=\"data row5 col3\" >1488.02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_74020_row6_col0\" class=\"data row6 col0\" >sharegpt_gpt4.jsonl</td>\n",
       "      <td id=\"T_74020_row6_col1\" class=\"data row6 col1\" >4,583</td>\n",
       "      <td id=\"T_74020_row6_col2\" class=\"data row6 col2\" >10,868,124</td>\n",
       "      <td id=\"T_74020_row6_col3\" class=\"data row6 col3\" >2371.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row7\" class=\"row_heading level0 row7\" >3</th>\n",
       "      <td id=\"T_74020_row7_col0\" class=\"data row7 col0\" >agent_instruct.jsonl</td>\n",
       "      <td id=\"T_74020_row7_col1\" class=\"data row7 col1\" >1,866</td>\n",
       "      <td id=\"T_74020_row7_col2\" class=\"data row7 col2\" >2,501,909</td>\n",
       "      <td id=\"T_74020_row7_col3\" class=\"data row7 col3\" >1340.79</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_74020_level0_row8\" class=\"row_heading level0 row8\" >7</th>\n",
       "      <td id=\"T_74020_row8_col0\" class=\"data row8 col0\" >nov2_gpt4hard411.jsonl</td>\n",
       "      <td id=\"T_74020_row8_col1\" class=\"data row8 col1\" >411</td>\n",
       "      <td id=\"T_74020_row8_col2\" class=\"data row8 col2\" >592,701</td>\n",
       "      <td id=\"T_74020_row8_col3\" class=\"data row8 col3\" >1442.09</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fd9481e6970>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display format per numeric column; the same columns receive the color gradient.\n",
    "stat_formats = {\n",
    "    \"num_examples\": \"{:,}\",\n",
    "    \"num_tokens\": \"{:,}\",\n",
    "    \"avg_tokens\": \"{:.2f}\",\n",
    "}\n",
    "\n",
    "# Largest datasets first; the styled frame is the value displayed by this cell.\n",
    "(\n",
    "    dataset_stats_df\n",
    "    .sort_values(\"num_examples\", ascending=False)\n",
    "    .style\n",
    "    .format(stat_formats)\n",
    "    .background_gradient(subset=list(stat_formats), cmap=\"Blues\")\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-agent",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
